/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/mindrecord/include/shard_reader.h"

#include <algorithm>
#include <thread>

#include "utils/file_utils.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace mindrecord {
// Convert a decimal string into the requested arithmetic type
// (int32_t/int64_t/float/double) using stream extraction. A failed
// extraction yields a value-initialized (zero) result, matching the
// C++11 semantics of operator>> on arithmetic types.
template <class Type>
Type StringToNum(const std::string &str) {
  Type value{};
  std::istringstream parser(str);
  parser >> value;
  return value;
}

// Default-construct a reader with empty/zeroed state; the real setup
// (header parsing, meta-db verification, load-mode selection) happens in Init().
ShardReader::ShardReader()
    : header_size_(0),
      page_size_(0),
      shard_count_(0),
      n_consumer_(0),
      num_padded_(0),
      num_rows_(0),
      total_blob_size_(0),
      sample_id_position_(0),
      deliver_id_(0),
      load_mode_(LoadMode::kFast),  // fast mode is the default until Init() decides otherwise
      shard_sample_count_() {}

// Extract the comparable header fields and the shard address list from one
// mindrecord file.
//
// @param file_path     path of a single mindrecord file.
// @param meta_data_ptr output: subset of header fields used to compare shards.
// @param addresses_ptr output: shard file names recorded in the header.
// @return Status error if the file is invalid or the header cannot be built.
Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
                            std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
  // Guard both output pointers; the original only checked addresses_ptr but
  // dereferenced meta_data_ptr unconditionally below.
  RETURN_UNEXPECTED_IF_NULL_MR(meta_data_ptr);
  RETURN_UNEXPECTED_IF_NULL_MR(addresses_ptr);
  RETURN_IF_NOT_OK_MR(CheckFile(file_path));
  std::shared_ptr<json> header_ptr;
  RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(file_path, &header_ptr));

  // Only these fields participate in the "all shards share the same metadata"
  // comparison done by Init().
  *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]},
                    {"version", (*header_ptr)["version"]},         {"index_fields", (*header_ptr)["index_fields"]},
                    {"schema", (*header_ptr)["schema"]},           {"blob_fields", (*header_ptr)["blob_fields"]}};
  std::vector<std::string> addresses_vec = (*header_ptr)["shard_addresses"];
  *addresses_ptr = std::make_shared<std::vector<std::string>>(addresses_vec);
  return Status::OK();
}

// Initialize the reader: validate every mindrecord file, open its sqlite meta
// database, build the shard header, count rows and choose a load mode
// (fast / lazy / slow) from the total sample count.
//
// @param file_paths   one or more mindrecord file paths.
// @param load_dataset true: file_paths[0] lists its sibling shards itself;
//                     false: load exactly the files given.
Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
  std::string file_path = file_paths[0];
  auto first_meta_data_ptr = std::make_shared<json>();
  std::shared_ptr<std::vector<std::string>> addresses_ptr;
  RETURN_IF_NOT_OK_MR(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr));
  if (file_paths.size() == 1 && load_dataset == true) {
    auto ds = std::make_shared<std::vector<std::string>>();
    RETURN_IF_NOT_OK_MR(GetDatasetFiles(file_path, *addresses_ptr, &ds));
    file_paths_ = *ds;  // load files according to shard_addresses
  } else if (file_paths.size() >= 1 && load_dataset == false) {
    file_paths_ = file_paths;  // load files according to the input
  } else {
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] The values of 'load_dataset' and 'file_paths' are not as expected.");
  }
  // Every shard must carry identical metadata (schema, page size, ...), and
  // each one gets its meta db opened and verified.
  for (const auto &file : file_paths_) {
    auto meta_data_ptr = std::make_shared<json>();
    RETURN_IF_NOT_OK_MR(GetMeta(file, meta_data_ptr, &addresses_ptr));
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      *meta_data_ptr == *first_meta_data_ptr,
      "Invalid file, the metadata of mindrecord file: " + file +
        " is different from others, please make sure all the mindrecord files generated by the same script.");
    sqlite3 *db = nullptr;
    RETURN_IF_NOT_OK_MR(VerifyDataset(&db, file));
    database_paths_.push_back(db);
  }
  ShardHeader sh = ShardHeader();
  RETURN_IF_NOT_OK_MR(sh.BuildDataset(file_paths_, load_dataset));
  shard_header_ = std::make_shared<ShardHeader>(sh);
  header_size_ = shard_header_->GetHeaderSize();
  page_size_ = shard_header_->GetPageSize();
  // version < 3.0
  if ((*first_meta_data_ptr)["version"] < kVersion) {
    shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
  } else {
    shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
  }
  num_rows_ = 0;
  auto row_group_summary = ReadRowGroupSummary();

  // clear the shard_sample_count_, because it will be insert when Launch func
  shard_sample_count_.clear();

  // Summary tuple layout: (shard_id, group_id, start_row, number_of_rows).
  constexpr int64_t get_index = 3;
  for (const auto &rg : row_group_summary) {
    num_rows_ += std::get<get_index>(rg);
  }

  // Pick the load mode from the sample count: fast keeps all metadata in
  // memory; lazy/slow trade loading speed for a smaller memory footprint.
  if (num_rows_ > SLOW_LOAD_THRESHOLD) {
    load_mode_ = LoadMode::kSlow;
    tasks_.load_mode_ = LoadMode::kSlow;
    MS_LOG(INFO) << "The number of samples is larger than " << SLOW_LOAD_THRESHOLD
                 << ", enable slow load mode. If you want to speed up data loading, "
                 << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
                 << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
                 << "according to the current samples.";
  } else if (num_rows_ > LAZY_LOAD_THRESHOLD) {
    load_mode_ = LoadMode::kLazy;
    tasks_.load_mode_ = LoadMode::kLazy;
    MS_LOG(INFO) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
                 << ", enable lazy load mode. If you want to speed up data loading, "
                 << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
                 << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
                 << "according to the current samples.";
  } else {
    load_mode_ = LoadMode::kFast;
    tasks_.load_mode_ = LoadMode::kFast;
  }

  // The shuffle mode requested by sampling operators may override the choice.
  UpdateLoadModeByShuffleMode();

  auto disk_size = page_size_ * row_group_summary.size();
  auto compression_size = shard_header_->GetCompressionSize();
  total_blob_size_ = disk_size + compression_size;
  MS_LOG(INFO) << "The size of blob data on disk: " << disk_size
               << " , additional uncompression size: " << compression_size
               << " , total blob size: " << total_blob_size_;

  MS_LOG(INFO) << "Succeed to get metadata from mindrecord files";

  return Status::OK();
}

// Reconcile the load mode with each operator's shuffle mode: the load mode
// determines how the task list is generated, and that task list is later
// shuffled according to the shuffle mode, so the two must agree.
void ShardReader::UpdateLoadModeByShuffleMode() {
  for (auto &op : operators_) {
    MS_LOG(INFO) << "The shuffle mode of the operator is " << op->GetShuffleMode();
    std::string info = "Update the load mode from " + LoadModeToStr(load_mode_) + " to ";
    switch (op->GetShuffleMode()) {
      // All of these shuffle modes force lazy loading.
      case dataset::ShuffleMode::kGlobal:
      case dataset::ShuffleMode::kFiles:
      case dataset::ShuffleMode::kInfile:
      case dataset::ShuffleMode::kFalse:
        load_mode_ = LoadMode::kLazy;
        tasks_.load_mode_ = LoadMode::kLazy;
        break;
      // Partial shuffle requires the slow load path.
      case dataset::ShuffleMode::kPartial:
        load_mode_ = LoadMode::kSlow;
        tasks_.load_mode_ = LoadMode::kSlow;
        break;
      // Adaptive works the other way round: derive the shuffle mode from the
      // load mode already chosen by Init().
      case dataset::ShuffleMode::kAdaptive:
        if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
          op->UpdateShuffleMode(dataset::ShuffleMode::kGlobal);
        } else if (load_mode_ == LoadMode::kSlow) {
          op->UpdateShuffleMode(dataset::ShuffleMode::kPartial);
        }
        break;
      default:
        // No change to the load mode for any other shuffle mode.
        break;
    }
    MS_LOG(INFO) << info << LoadModeToStr(load_mode_) << ".";
  }
}

// Open the shard's sqlite meta database read-only and verify that the name
// stored in its SHARD_NAME table matches the mindrecord file name (i.e. the
// files were not renamed after creation). On success *db holds an open
// connection with a read transaction started; on failure *db is closed and
// reset to nullptr.
Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
  std::string path_utf8 = "";
#if defined(_WIN32) || defined(_WIN64)
  path_utf8 = FileUtils::GB2312ToUTF_8((file + ".db").data());
#endif
  if (path_utf8.empty()) {
    path_utf8 = file + ".db";
  }

  // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  // use "unix-none" to avoid flock and achieve better performance on shared storage platform
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, "unix-none") == SQLITE_OK,
    "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
      ".db exists and do not rename the mindrecord file and meta file.");
#else
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
    "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
      ".db exists and do not rename the mindrecord file and meta file.");
#endif
  MS_LOG(DEBUG) << "Succeed to open meta file, path: " << file << ".db.";

  // starting a transaction during a read-only select operation can solve the problem of frequently
  // accessing *-journal / *-wal files.
  auto sql_code = sqlite3_exec(*db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
  if (sql_code != SQLITE_OK) {
    // BUGFIX: a sqlite3 connection must be released with sqlite3_close();
    // sqlite3_free() is only for memory obtained from sqlite3_malloc() and
    // must never be applied to a database handle.
    sqlite3_close(*db);
    *db = nullptr;
    RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: " +
                                std::to_string(sql_code));
  }

  string sql = "SELECT NAME from SHARD_NAME;";
  std::vector<std::vector<std::string>> name;
  char *errmsg = nullptr;
  if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
    std::ostringstream oss;
    // Guard against streaming a null char* (UB) when sqlite gives no message.
    oss << "Failed to execute the sql [ " << sql << " ] while verifying meta file, "
        << (errmsg == nullptr ? "" : errmsg) << ".\nPlease check the meta file: " + file + ".db";
    sqlite3_free(errmsg);
    sqlite3_close(*db);
    *db = nullptr;
    RETURN_STATUS_UNEXPECTED_MR(oss.str());
  } else {
    std::shared_ptr<std::string> fn_ptr;
    RETURN_IF_NOT_OK_MR(GetFileName(file, &fn_ptr));
    if (name.empty() || name[0][0] != *fn_ptr) {
      sqlite3_free(errmsg);
      sqlite3_close(*db);
      *db = nullptr;
      RETURN_STATUS_UNEXPECTED_MR("Invalid file, mindrecord meta file: " + file + ".db and mindrecord file: " + file +
                                  " can not match. Please do not rename the mindrecord file or meta file.");
    }
  }
  return Status::OK();
}

// Validate that every requested column exists in the (single) schema.
//
// @param selected_columns column names from the user's 'column_list'.
// @return Status error naming the first column missing from the schema.
Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
  auto schema_ptr = GetShardHeader()->GetSchemas()[0];
  auto schema = schema_ptr->GetSchema()["schema"];
  // Range-for avoids the original signed/unsigned comparison (int index
  // against size_t size()).
  for (const auto &column : selected_columns) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(column) != schema.end(),
                                    "Invalid data, column name: " + column +
                                      " can not found in schema. Please check the 'column_list'.");
  }
  return Status::OK();
}

// Open one independent read stream per (consumer, shard) pair so consumers
// can read concurrently without sharing a file position.
//
// @param n_consumer number of consumer threads that will read in parallel.
// @return Status error if any file cannot be resolved or opened.
Status ShardReader::Open(int n_consumer) {
  file_streams_random_ =
    std::vector<std::vector<std::shared_ptr<std::fstream>>>(n_consumer, std::vector<std::shared_ptr<std::fstream>>());
  for (const auto &file : file_paths_) {
    // The resolved real path is identical for every consumer, so compute it
    // once per file instead of once per (file, consumer) pair — this also
    // matches the structure of ExtendRandomFileStreams().
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().c_str());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(),
      "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    for (int j = 0; j < n_consumer; ++j) {
      std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
          "open files limit(ulimit -a): " +
          file);
      }
      file_streams_random_[j].push_back(fs);
    }
    MS_LOG(INFO) << "Succeed to open file, path: " << file;
  }
  return Status::OK();
}

// Grow the per-consumer stream matrix by n_new_consumers, opening a fresh
// stream for every shard file for each new consumer. Must be called after
// Open(); updates n_consumer_ on success.
//
// @param n_new_consumers number of consumers to add (must be > 0 and keep the
//                        total within the allowed thread limit).
Status ShardReader::ExtendRandomFileStreams(const int n_new_consumers) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_new_consumers > 0,
                                  "n_new_consumers must be a positive number. Got: " + std::to_string(n_new_consumers));
  CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
                                  "ExtendRandomFileStreams() must not be called prior to calling Open()");
  // make sure we won't exceed the number of allowed threads.
  uint32_t thread_limit = GetMaxThreadNum();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ + n_new_consumers <= thread_limit,
                                  "Requested increase in number of consumers will cause it to be above the number of "
                                  "allowed threads. n_new_consumers: " +
                                    std::to_string(n_new_consumers) +
                                    ", new n_consumers: " + std::to_string(n_consumer_ + n_new_consumers));

  // Reserve empty rows for the new consumers; they are filled per file below.
  for (int i = 0; i < n_new_consumers; i++) {
    (void)file_streams_random_.emplace_back(std::vector<std::shared_ptr<std::fstream>>());
  }

  for (const auto &file : file_paths_) {
    // Resolve the real path once per file; it is the same for all consumers.
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().data());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    // Only the newly added consumer rows [n_consumer_, n_consumer_+n_new) get
    // streams; existing consumers keep theirs.
    for (int j = n_consumer_; j < n_consumer_ + n_new_consumers; ++j) {
      std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
          "open files limit(ulimit -a): " +
          file);
      }
      file_streams_random_[j].push_back(fs);
    }
    MS_LOG(INFO) << "Succeed to open file, path: " << file;
  }
  n_consumer_ += n_new_consumers;
  MS_LOG(INFO) << "n_consumer_ is increased by " + std::to_string(n_new_consumers) + " to " +
                    std::to_string(n_consumer_);

  return Status::OK();
}

// Shrink the per-consumer stream matrix by n_remove_consumers, closing every
// stream owned by the removed consumers. Must be called after Open(); updates
// n_consumer_ on success.
//
// @param n_remove_consumers number of consumers to remove (must be > 0 and
//                           keep the total at or above kMinConsumerCount).
Status ShardReader::ShrinkRandomFileStreams(const int n_remove_consumers) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    n_remove_consumers > 0, "n_remove_consumers must be a positive number. Got: " + std::to_string(n_remove_consumers));
  CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
                                  "ShrinkRandomFileStreams() must not be called prior to calling Open()");
  // make sure we won't go below the number of allowed threads.
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ - n_remove_consumers >= kMinConsumerCount,
                                  "Requested decrease in number of consumers will cause it to be below the number of "
                                  "allowed threads. n_remove_consumers: " +
                                    std::to_string(n_remove_consumers) +
                                    ", new n_consumers: " + std::to_string(n_consumer_ - n_remove_consumers));

  // Close the streams of the last n_remove_consumers rows, popping each row
  // after its streams are closed.
  for (int i = n_consumer_ - 1; i >= n_consumer_ - n_remove_consumers; i--) {
    for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
      if (file_streams_random_[i][j] != nullptr) {
        file_streams_random_[i][j]->close();
      }
    }
    file_streams_random_.pop_back();
  }
  n_consumer_ -= n_remove_consumers;
  MS_LOG(INFO) << "n_consumer_ is decreased by " + std::to_string(n_remove_consumers) + " to " +
                    std::to_string(n_consumer_);

  return Status::OK();
}

// Release every owned resource: the shared read streams, the per-consumer
// random-access streams, and the sqlite meta connections (ending the read
// transaction started in VerifyDataset() before closing each one).
void ShardReader::FileStreamsOperator() {
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; --i) {
    if (file_streams_[i] != nullptr) {
      file_streams_[i]->close();
    }
  }
  for (int i = static_cast<int>(file_streams_random_.size()) - 1; i >= 0; --i) {
    for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
      if (file_streams_random_[i][j] != nullptr) {
        file_streams_random_[i][j]->close();
      }
    }
  }
  for (int i = static_cast<int>(database_paths_.size()) - 1; i >= 0; --i) {
    if (database_paths_[i] != nullptr) {
      auto sql_code = sqlite3_exec(database_paths_[i], "END TRANSACTION;", nullptr, nullptr, nullptr);
      if (sql_code != SQLITE_OK) {
        sqlite3_close(database_paths_[i]);
        // Null out the handle on the error path too; the original left a
        // dangling pointer to a closed connection here.
        database_paths_[i] = nullptr;
        MS_LOG(ERROR) << "Execute SQL statement `END TRANSACTION;` failed, SQLite result code: "
                      << std::to_string(sql_code);
        continue;
      }
      auto ret = sqlite3_close(database_paths_[i]);
      if (ret != SQLITE_OK) {
        MS_LOG(ERROR) << "[Internal ERROR] Failed to close meta file, " << ret << ".";
      }
      database_paths_[i] = nullptr;
    }
  }
}

// Stop worker threads and release all file/db handles on destruction.
ShardReader::~ShardReader() { Close(); }

// Stop all reading: raise the interrupt flag, wake and join every worker
// thread, then close file streams and sqlite connections. Safe to call more
// than once (the destructor calls it as well).
void ShardReader::Close() {
  {
    // Set the flag under the delivery lock so waiting workers observe it
    // before being notified.
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  cv_delivery_.notify_all();

  // Wait for all threads to finish
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  // Release file streams and sqlite meta connections.
  FileStreamsOperator();
}

// Accessor for the parsed shard header (schemas, pages, index fields).
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }

// Accessor for the column helper built in Init().
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }

// Number of shard files, as recorded in the shard header.
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }

// Total number of samples across all shards, computed in Init().
int64_t ShardReader::GetNumRows() const { return num_rows_; }

// Number of samples remaining after sampling operators have been applied.
int64_t ShardReader::GetNumRowsAfterSampling() const { return tasks_.SizeAfterSampling(); }

// Walk every blob page of every shard and summarize its row groups.
// Returns tuples of (shard_id, page_type_id, start_row_id, number_of_rows);
// an empty vector signals a corrupt page (start row beyond end row). Also
// appends the running sample total per shard into shard_sample_count_.
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
  std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
  int shard_count = shard_header_->GetShardCount();
  if (shard_count <= 0) {
    return row_group_summary;
  }

  // NOTE(review): total_count is uint32_t while num_rows_ is int64_t —
  // datasets beyond ~4.29e9 samples would overflow here; confirm intended.
  uint32_t total_count = 0;
  for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
    // return -1 when page's size equals to 0.
    auto last_page_id = shard_header_->GetLastPageId(shard_id);
    if (static_cast<int>(last_page_id) == -1) {
      // Empty mindrecord file which does not contain any samples
      MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
                      << " does not contain any samples, pls remove it.";
      row_group_summary.emplace_back(shard_id, 0, 0, 0);
      shard_sample_count_.push_back(total_count);
      continue;
    }
    for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
      std::shared_ptr<Page> page_ptr;
      (void)shard_header_->GetPage(shard_id, page_id, &page_ptr);
      // Only blob pages carry sample data; raw (label) pages are skipped.
      if (page_ptr->GetPageType() != kPageTypeBlob) {
        continue;
      }
      uint64_t start_row_id = page_ptr->GetStartRowID();
      if (start_row_id > page_ptr->GetEndRowID()) {
        // Corrupt page bounds: bail out with an empty summary.
        return std::vector<std::tuple<int, int, int, uint64_t>>();
      }
      uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id;
      total_count += number_of_rows;
      row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows);
    }
    // Cumulative sample count up to and including this shard.
    shard_sample_count_.push_back(total_count);
  }

  return row_group_summary;
}

// Convert SQL index rows for one shard into blob offsets and per-sample json
// labels.
//
// Expected row layout (from the INDEXES queries built by the callers):
//   [0] group id, [1] blob offset start, [2] blob offset end, and — when not
//   all fields are indexed — [3] raw page id, [4] label offset start,
//   [5] label offset end; otherwise [3..] hold the indexed column values.
//
// @param labels     SQL result rows for this shard.
// @param fs         open stream on the shard file (used to read raw labels).
// @param offset_ptr output: per-shard [shard_id, group_id, start, end] tuples.
// @param shard_id   shard the rows belong to.
// @param columns    columns requested by the caller (empty = all).
// @param col_val_ptr output: per-shard json label for every sample.
Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
                                       std::shared_ptr<std::fstream> fs,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       int shard_id, const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    try {
      uint64_t group_id = std::stoull(labels[i][0]);
      // +kInt64Len skips the length word stored in front of the blob payload.
      uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
      uint64_t offset_end = std::stoull(labels[i][2]);
      CHECK_FAIL_RETURN_UNEXPECTED_MR(offset_end >= offset_start,
                                      "The sample's end offset: " + std::to_string(offset_end) +
                                        " should >= start offset: " + std::to_string(offset_start) + ", check fail.");
      (*offset_ptr)[shard_id].emplace_back(
        std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
      if (!all_in_index_) {
        // Some requested field is not indexed: read the msgpack-encoded label
        // straight from the raw page in the shard file.
        int raw_page_id = std::stoi(labels[i][3]);
        uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
        uint64_t label_end = std::stoull(labels[i][5]);
        CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
                                        "The sample's end offset: " + std::to_string(label_end) +
                                          " should >= start offset: " + std::to_string(label_start) + ", check fail.");
        auto len = label_end - label_start;
        auto label_raw = std::vector<uint8_t>(len);
        auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
        if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
        }
        auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
        if (!io_read.good() || io_read.fail() || io_read.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
        }
        json label_json = json::from_msgpack(label_raw);
        json tmp;
        // Project onto the requested columns; empty columns means keep all.
        if (!columns.empty()) {
          for (const auto &col : columns) {
            if (label_json.find(col) != label_json.end()) {
              tmp[col] = label_json[col];
            }
          }
        } else {
          tmp = label_json;
        }
        (*col_val_ptr)[shard_id].emplace_back(tmp);
      } else {
        // All fields are indexed: the column values are already in the row.
        json construct_json;
        RETURN_IF_NOT_OK_MR(ConvertJsonValue(labels[i], columns, schema, &construct_json));
        (*col_val_ptr)[shard_id].emplace_back(construct_json);
      }
    } catch (std::out_of_range &e) {
      // std::sto* threw: the offsets stored in the meta db are malformed.
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (std::invalid_argument &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (...) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Unexpected exception raised in ConvertLabelToJson function.");
    }
  }

  return Status::OK();
}

// Build a json object from one SQL index row, converting each requested
// column's string value to the type declared in the schema (int32/int64/
// float32/float64; anything else is kept as a string).
Status ShardReader::ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
                                     const json &schema, json *value) {
  // Column values start at offset 3 in the row, after the group id and the
  // two blob offsets.
  constexpr int64_t kColumnOffset = 3;
  for (size_t idx = 0; idx < columns.size(); ++idx) {
    const auto &col = columns[idx];
    const auto &raw = label[idx + kColumnOffset];
    const auto &type = schema[col]["type"];
    if (type == "int32") {
      (*value)[col] = StringToNum<int32_t>(raw);
    } else if (type == "int64") {
      (*value)[col] = StringToNum<int64_t>(raw);
    } else if (type == "float32") {
      (*value)[col] = StringToNum<float>(raw);
    } else if (type == "float64") {
      (*value)[col] = StringToNum<double>(raw);
    } else {
      (*value)[col] = std::string(raw);
    }
  }
  return Status::OK();
}
// Run `sql` against one shard's meta db and convert the result rows into blob
// offsets and json labels via ConvertLabelToJson().
//
// @param shard_id    index of the shard (selects db and file stream).
// @param consumer_id consumer whose file stream is used for raw-page reads.
// @param sql         SELECT over the INDEXES table.
// @param columns     requested columns.
// @param offset_ptr  output: per-shard blob offset tuples.
// @param col_val_ptr output: per-shard json labels.
Status ShardReader::ReadAllRowsInShard(int shard_id, const int32_t &consumer_id, const std::string &sql,
                                       const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  auto db = database_paths_[shard_id];
  std::vector<std::vector<std::string>> labels;
  char *errmsg = nullptr;
  int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
  if (rc != SQLITE_OK) {
    std::ostringstream oss;
    // Guard against streaming a null char* when sqlite provides no message.
    oss << "[Internal ERROR] Failed to execute the sql [ " << sql << " ] while reading meta file, "
        << (errmsg == nullptr ? "" : errmsg);
    sqlite3_free(errmsg);
    sqlite3_close(db);
    // Reset the stored handle as well: `db` is only a local copy, so clearing
    // it alone (as the original did) left a dangling pointer in the member.
    database_paths_[shard_id] = nullptr;
    RETURN_STATUS_UNEXPECTED_MR(oss.str());
  }
  MS_LOG(DEBUG) << "Succeed to get " << labels.size() << " records from shard " << std::to_string(shard_id)
                << " index.";

  // errmsg is only set by sqlite on failure; nothing to free on this path.
  return ConvertLabelToJson(labels, file_streams_random_[consumer_id][shard_id], offset_ptr, shard_id, columns,
                            col_val_ptr);
}

// Collect the distinct values of `category_field` across all shards, querying
// each shard's meta db on its own thread and merging results into
// *category_ptr (GetClassesInShard serializes insertion with shard_locker_).
//
// @param category_field the 'class_column' requested by PKSampler.
// @param category_ptr   output: set of distinct category values.
Status ShardReader::GetAllClasses(const std::string &category_field,
                                  std::shared_ptr<std::set<std::string>> category_ptr) {
  // Map field name -> schema id so the SQL column name can be generated.
  std::map<std::string, uint64_t> index_columns;
  for (auto &field : GetShardHeader()->GetFields()) {
    index_columns[field.second] = field.first;
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    index_columns.find(category_field) != index_columns.end(),
    "Invalid data, 'class_column': " + category_field +
      " can not found in fields of mindrecord files. Please check 'class_column' in PKSampler.");
  std::shared_ptr<std::string> fn_ptr;
  RETURN_IF_NOT_OK_MR(
    ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  // One worker thread per shard; each queries its own db connection.
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
  for (int x = 0; x < shard_count_; x++) {
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
  }

  for (int x = 0; x < shard_count_; x++) {
    threads[x].join();
  }
  return Status::OK();
}

// Worker-thread body for GetAllClasses(): run `sql` against one shard's meta
// db and merge the first column of every result row into the shared category
// set (guarded by shard_locker_). Errors are logged, not returned, since this
// runs detached from the Status-based call chain.
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
                                    std::shared_ptr<std::set<std::string>> category_ptr) {
  if (db == nullptr) {
    return;
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_id)).c_str());
#endif
  std::vector<std::vector<std::string>> columns;
  char *errmsg = nullptr;
  int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
  if (ret != SQLITE_OK) {
    // BUGFIX: log the message BEFORE releasing it — the original called
    // sqlite3_free(errmsg) first and then streamed the freed pointer into the
    // log (use-after-free). Also guard against a null message.
    MS_LOG(ERROR) << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
                  << " ] while reading meta file, " << (errmsg == nullptr ? "" : errmsg);
    sqlite3_free(errmsg);
    sqlite3_close(db);
    db = nullptr;
    return;
  }
  MS_LOG(INFO) << "Succeed to get " << columns.size() << " records from shard " << std::to_string(shard_id)
               << " index.";
  // Merge under the lock: multiple shard threads share category_ptr.
  std::lock_guard<std::mutex> lck(shard_locker_);
  for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
    category_ptr->emplace(columns[i][0]);
  }
  sqlite3_free(errmsg);
}

// Read offsets and labels for every sample in every shard, querying shards in
// parallel with std::async and aggregating the results into one ROW_GROUPS.
//
// @param columns       columns to fetch.
// @param row_group_ptr output: (offsets, labels) for all shards.
Status ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns,
                                    std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
  auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
    shard_count_, std::vector<std::vector<uint64_t>>{});
  auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});

  if (all_in_index_) {
    // Every requested column is indexed: select its generated column name.
    for (unsigned int i = 0; i < columns.size(); ++i) {
      fields += ',';
      std::shared_ptr<std::string> fn_ptr;
      RETURN_IF_NOT_OK_MR(
        ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
      fields += *fn_ptr;
    }
  } else {  // fetch raw data from Raw page while some field is not index.
    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
  }

  std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;";

  // One async task per shard; all use consumer 0's file streams since this
  // runs before the consumer threads start delivering.
  std::vector<std::future<Status>> async_results;
  auto status = Status::OK();
  for (int x = 0; x < shard_count_; x++) {
    async_results.push_back(std::async(std::launch::async, &ShardReader::ReadAllRowsInShard, this, x, 0, sql, columns,
                                       offset_ptr, col_val_ptr));
  }

  // Drain every future so all tasks finish; keep only the first error.
  for (auto i = 0; i < async_results.size(); i++) {
    auto res = async_results[i].get();
    if (res.IsError() && status.IsOk()) {
      status = res;
    }
  }
  if (status.IsError()) {
    return status;
  }
  *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
  return Status::OK();
}

// Read offsets and labels for a single sample (identified by ROW_ID) from a
// single shard. Mirrors ReadAllRowGroup() but with a WHERE clause on ROW_ID.
//
// @param columns       columns to fetch.
// @param shard_id      shard that contains the sample.
// @param consumer_id   consumer whose file stream is used for raw-page reads.
// @param sample_id     ROW_ID of the sample inside the shard.
// @param row_group_ptr output: (offsets, labels) holding just that sample.
Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
                                                     const int32_t &consumer_id, const uint32_t &sample_id,
                                                     std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
  auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
    shard_count_, std::vector<std::vector<uint64_t>>{});
  auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
  if (all_in_index_) {
    // Every requested column is indexed: select its generated column name.
    for (unsigned int i = 0; i < columns.size(); ++i) {
      fields += ',';
      std::shared_ptr<std::string> fn_ptr;
      RETURN_IF_NOT_OK_MR(
        ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
      fields += *fn_ptr;
    }
  } else {  // fetch raw data from Raw page while some field is not index.
    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
  }

  std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);

  RETURN_IF_NOT_OK_MR(ReadAllRowsInShard(shard_id, consumer_id, sql, columns, offset_ptr, col_val_ptr));
  *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
  return Status::OK();
}

// Build a brief for one row group: file name, blob page length/offset, the
// per-sample image offsets and the labels of every sample in the group.
//
// @param group_id  row group inside the shard.
// @param shard_id  shard index.
// @param columns   label columns to fetch.
// @param row_group_brief_ptr output brief.
Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
                                      std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  std::string file_name = file_paths_[shard_id];
  uint64_t page_length = page_ptr->GetPageSize();
  // Absolute byte offset of the blob page inside the shard file.
  uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id);
  auto labels_ptr = std::make_shared<std::vector<json>>();
  // {"", ""} = no filter criteria: fetch labels for the whole page.
  RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr));
  *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
                                                           std::move(*labels_ptr));
  return Status::OK();
}

// Build a brief for one row group restricted to samples matching a
// (column, value) criteria pair. When no sample matches, an empty brief is
// returned.
//
// @param group_id  row group inside the shard.
// @param shard_id  shard index.
// @param criteria  (column name, value) filter; the column must be indexed.
// @param columns   label columns to fetch.
// @param row_group_brief_ptr output brief.
Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
                                         const std::pair<std::string, std::string> &criteria,
                                         const std::vector<std::string> &columns,
                                         std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  vector<string> criteria_list{criteria.first};
  RETURN_IF_NOT_OK_MR(CheckColumnList(criteria_list));
  std::string file_name = file_paths_[shard_id];
  uint64_t page_length = page_ptr->GetPageSize();
  uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria);
  if (image_offset.empty()) {
    // No sample in this page matches the criteria: return the empty brief.
    // (The original assigned it and then fell through, immediately
    // overwriting it with a full brief — a dead store.)
    *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>();
    return Status::OK();
  }
  auto labels_ptr = std::make_shared<std::vector<json>>();
  RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr));
  *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
                                                           std::move(*labels_ptr));
  return Status::OK();
}

// sqlite3_exec row callback: appends one result row, as a vector of strings,
// to the std::vector<std::vector<std::string>> passed through p_data.
// Always returns 0 so sqlite continues stepping through the result set.
int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) {
  auto *records = static_cast<std::vector<std::vector<std::string>> *>(p_data);
  const bool valid_width = num_fields > 0 && num_fields <= kMaxFieldCount;
  if (valid_width) {
    // SQLite reports NULL column values as nullptr; substitute empty C strings
    // so the range construction below can build std::string values safely.
    for (int col = 0; col < num_fields; ++col) {
      if (p_fields[col] == nullptr) {
        p_fields[col] = const_cast<char *>("");
      }
    }
  }
  records->emplace_back(p_fields, p_fields + num_fields);
  return 0;
}

std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int shard_id,
                                                               const std::pair<std::string, std::string> &criteria) {
  auto db = database_paths_[shard_id];

  std::string sql = "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";

  // whether use index search
  if (!criteria.first.empty()) {
    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
  }
  sql += ";";
  std::vector<std::vector<std::string>> image_offsets;

  sqlite3_stmt *stmt = nullptr;
  if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
    MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to prepare statement [ " << sql << " ].";
  }

  // bind the PAGE_ID_BLOB
  int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
  if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
    (void)sqlite3_finalize(stmt);
    MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
                      << ", value: " << std::to_string(page_id);
  }

  // bind the criteria
  if (!criteria.first.empty()) {
    index = sqlite3_bind_parameter_index(stmt, ":criteria");
    if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
      (void)sqlite3_finalize(stmt);
      MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
                        << ", value: " + criteria.second;
    }
  }

  int rc = sqlite3_step(stmt);
  while (rc != SQLITE_DONE) {
    vector<string> tmp;
    int ncols = sqlite3_column_count(stmt);
    for (int i = 0; i < ncols; i++) {
      tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
    }
    image_offsets.push_back(tmp);
    rc = sqlite3_step(stmt);
  }

  auto finalize = sqlite3_finalize(stmt);
  if (finalize != SQLITE_OK) {
    MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to finalize sql stmt, error code: " << std::to_string(finalize);
  }

  MS_LOG(DEBUG) << "Succeed to get " << image_offsets.size() << " records from index.";

  std::vector<std::vector<uint64_t>> res;
  for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) {
    res.emplace_back(std::vector<uint64_t>{0, 0});
  }
  for (int i = 0; i < static_cast<int>(image_offsets.size()); i++) {
    const auto &image_offset = image_offsets[i];
    res[i][0] = std::stoull(image_offset[0]) + kInt64Len;
    res[i][1] = std::stoull(image_offset[1]);
    if (res[i][1] < res[i][0]) {
      MS_LOG(EXCEPTION) << "The sample's end offset: " << std::to_string(res[i][1])
                        << " should >= start offset: " << std::to_string(res[i][0]) << ", check fail.";
    }
  }
  return res;
}

// Collect the distinct blob page ids of shard `shard_id` whose indexed field
// matches `criteria` (pair of <field, value>); appends them to **pages_ptr.
Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
                                       std::shared_ptr<std::vector<uint64_t>> *pages_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
  auto db = database_paths_[shard_id];

  std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";

  if (!criteria.first.empty()) {
    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
  }
  sql += ";";
  std::vector<std::vector<std::string>> page_ids;

  sqlite3_stmt *stmt = nullptr;
  if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
    (void)sqlite3_finalize(stmt);
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
  }

  if (!criteria.first.empty()) {
    int index = sqlite3_bind_parameter_index(stmt, ":criteria");
    if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
      (void)sqlite3_finalize(stmt);
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
                                  std::to_string(index) + ", value: " + criteria.second);
    }
  }

  // Step only while rows are produced; looping on "rc != SQLITE_DONE" would
  // spin forever if sqlite3_step returned an error code.
  int rc = sqlite3_step(stmt);
  while (rc == SQLITE_ROW) {
    vector<string> tmp;
    int ncols = sqlite3_column_count(stmt);
    for (int i = 0; i < ncols; i++) {
      // Guard against NULL columns: sqlite3_column_text may return nullptr.
      const auto *text = sqlite3_column_text(stmt, i);
      tmp.emplace_back(text == nullptr ? "" : reinterpret_cast<const char *>(text));
    }
    page_ids.push_back(tmp);
    rc = sqlite3_step(stmt);
  }

  // sqlite3_finalize reports the error of the most recent failed step, if any.
  auto finalize = sqlite3_finalize(stmt);
  if (finalize != SQLITE_OK) {
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
                                std::to_string(finalize));
  }

  MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << " pages from index.";
  for (size_t i = 0; i < page_ids.size(); ++i) {
    (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
  }
  return Status::OK();
}

std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
  std::vector<std::string> blob_fields;
  for (auto &p : GetShardHeader()->GetSchemas()) {
    // assume one schema
    const auto &fields = p->GetBlobFields();
    blob_fields.assign(fields.begin(), fields.end());
    break;
  }
  return std::make_pair(kCV, blob_fields);
}

void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
  // assume different schemas do not contain same key.
  if (columns.empty()) {
    all_in_index_ = false;
    return;
  }
  for (auto &field : GetShardHeader()->GetFields()) {
    column_schema_id_[field.second] = field.first;
  }
  for (auto &col : columns) {
    if (column_schema_id_.find(col) == column_schema_id_.end()) {
      all_in_index_ = false;
      return;
    }
  }
}

// Prepare `sql`, bind the :page_id_blob and :criteria parameters, and append
// every result row (as strings) to *labels_ptr.  `db` must be an open sqlite
// connection; the statement is always finalized before returning.
Status ShardReader::QueryWithPageIdBlobAndCriteria(sqlite3 *db, const string &sql, const int &page_id,
                                                   const string &criteria,
                                                   std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
  sqlite3_stmt *stmt = nullptr;
  if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
  }

  // bind the PAGE_ID_BLOB
  int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
  if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
    (void)sqlite3_finalize(stmt);
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
                                std::to_string(index) + ", value: " + std::to_string(page_id));
  }

  // bind the criteria
  index = sqlite3_bind_parameter_index(stmt, ":criteria");
  if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
    (void)sqlite3_finalize(stmt);
    RETURN_STATUS_UNEXPECTED_MR(
      "[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: " + criteria);
  }
  // Step only while rows are produced; looping on "rc != SQLITE_DONE" would
  // spin forever if sqlite3_step returned an error code.
  int rc = sqlite3_step(stmt);
  while (rc == SQLITE_ROW) {
    vector<string> tmp;
    int ncols = sqlite3_column_count(stmt);
    for (int i = 0; i < ncols; i++) {
      // Guard against NULL columns: sqlite3_column_text may return nullptr.
      const auto *text = sqlite3_column_text(stmt, i);
      tmp.emplace_back(text == nullptr ? "" : reinterpret_cast<const char *>(text));
    }
    labels_ptr->push_back(tmp);
    rc = sqlite3_step(stmt);
  }

  // sqlite3_finalize reports the error of the most recent failed step, if any.
  auto finalize = sqlite3_finalize(stmt);
  if (finalize != SQLITE_OK) {
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
                                std::to_string(finalize));
  }
  return Status::OK();
}

// Read the raw msgpack label of each sample directly from the shard's binary
// file and project it onto `columns`.  Each entry of `label_offsets` holds
// three strings — per the query in GetLabelsFromPage these are
// (PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END).
Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
                                            const std::vector<std::vector<std::string>> &label_offsets,
                                            std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
  std::shared_ptr<std::fstream> fs = file_streams_random_[0][shard_id];

  // init the return
  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    (*labels_ptr)->emplace_back(json{});
  }

  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    const auto &labelOffset = label_offsets[i];
    if (labelOffset.size() < 3) {
      fs->close();
      // Message fixed: the check fails when size() < 3, so at least 3 entries
      // are required (the old text stated the inverted condition).
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] 'labelOffset' size should be greater than or equal to 3 but got: " +
                                  std::to_string(labelOffset.size()) + ".");
    }
    // skip the int64 length field stored in front of the raw data
    uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
    uint64_t label_end = std::stoull(labelOffset[2]);
    CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
                                    "The sample's end offset: " + std::to_string(label_end) +
                                      " should >= start offset: " + std::to_string(label_start) + ", check fail.");
    int raw_page_id = std::stoi(labelOffset[0]);
    auto len = label_end - label_start;
    auto label_raw = std::vector<uint8_t>(len);
    auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
    if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file, path: " + file_paths_[shard_id]);
    }

    // Use data() instead of &label_raw[0]: indexing element 0 of an empty
    // vector (len == 0) is undefined behavior, while data() is always valid.
    auto &io_read = fs->read(reinterpret_cast<char *>(label_raw.data()), len);
    if (!io_read.good() || io_read.fail() || io_read.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file, path: " + file_paths_[shard_id]);
    }

    // Decode the msgpack label and keep only the requested columns.
    json label_json = json::from_msgpack(label_raw);
    json tmp = label_json;
    for (auto &col : columns) {
      if (label_json.find(col) != label_json.end()) {
        tmp[col] = label_json[col];
      }
    }
    (*(*labels_ptr))[i] = tmp;
  }
  return Status::OK();
}

// Slow path of label retrieval: look up each sample's raw-page offsets in the
// sqlite index, then read and decode the msgpack labels from the binary file
// via GetLabelsFromBinaryFile.
Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
                                      const std::pair<std::string, std::string> &criteria,
                                      std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
  // get page info from sqlite
  auto db = database_paths_[shard_id];
  std::string sql =
    "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";

  auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
  if (!criteria.first.empty()) {
    // criteria given: filter on "<field>_<schema_id>" via the shared helper
    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria;";
    RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, label_offset_ptr));
  } else {
    sql += ";";
    sqlite3_stmt *stmt = nullptr;
    if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
    }

    // bind the PAGE_ID_BLOB
    int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
    if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
      (void)sqlite3_finalize(stmt);
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
                                  std::to_string(index) + ", value: " + std::to_string(page_id));
    }

    // Step only while rows are produced; looping on "rc != SQLITE_DONE" would
    // spin forever if sqlite3_step returned an error code.
    int rc = sqlite3_step(stmt);
    while (rc == SQLITE_ROW) {
      vector<string> tmp;
      int ncols = sqlite3_column_count(stmt);
      for (int i = 0; i < ncols; i++) {
        // Guard against NULL columns: sqlite3_column_text may return nullptr.
        const auto *text = sqlite3_column_text(stmt, i);
        tmp.emplace_back(text == nullptr ? "" : reinterpret_cast<const char *>(text));
      }
      label_offset_ptr->push_back(tmp);
      rc = sqlite3_step(stmt);
    }

    // sqlite3_finalize reports the error of the most recent failed step.
    auto finalize = sqlite3_finalize(stmt);
    if (finalize != SQLITE_OK) {
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
                                  std::to_string(finalize));
    }

    MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
  }
  // get labels from binary file
  return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr);
}

// Retrieve the label json of every sample on blob page `page_id` of shard
// `shard_id`, restricted to `columns` and optionally filtered by `criteria`
// (pair of <indexed field, value>).
// Fast path: when all_in_index_ is set, the values are read straight from the
// sqlite index and converted back to their schema-declared types.
// Slow path: otherwise delegate to GetLabelsFromPage, which decodes the
// msgpack labels from the binary shard file.
Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
                              const std::pair<std::string, std::string> &criteria,
                              std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
  if (all_in_index_) {
    auto db = database_paths_[shard_id];
    // Build the SELECT list "<col>_<schema_id>, ..."; "*" when no columns.
    std::string fields;
    for (unsigned int i = 0; i < columns.size(); ++i) {
      if (i > 0) {
        fields += ',';
      }
      uint64_t schema_id = column_schema_id_[columns[i]];
      fields += columns[i] + "_" + std::to_string(schema_id);
    }
    if (fields.empty()) {
      fields = "*";
    }
    auto labels = std::make_shared<std::vector<std::vector<std::string>>>();
    std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
    if (!criteria.first.empty()) {
      // Criteria given: filter on "<field>_<schema_id>" and run the shared
      // prepare/bind/step helper.
      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria;";
      RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, labels));
    } else {
      // No criteria: prepare and execute the statement inline, binding only
      // the page id.
      sql += ";";
      sqlite3_stmt *stmt = nullptr;
      if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
      }

      // bind the PAGE_ID_BLOB
      int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
      if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
        (void)sqlite3_finalize(stmt);
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
                                    std::to_string(index) + ", value: " + std::to_string(page_id));
      }

      // Collect every result row as a vector of strings.
      int rc = sqlite3_step(stmt);
      while (rc != SQLITE_DONE) {
        vector<string> tmp;
        int ncols = sqlite3_column_count(stmt);
        for (int i = 0; i < ncols; i++) {
          tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
        }
        labels->push_back(tmp);
        rc = sqlite3_step(stmt);
      }

      auto finalize = sqlite3_finalize(stmt);
      if (finalize != SQLITE_OK) {
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
                                    std::to_string(finalize));
      }

      MS_LOG(DEBUG) << "Succeed to get " << labels->size() << " records from index.";
    }
    // Pre-size the output with empty json objects, one per fetched row.
    for (unsigned int i = 0; i < labels->size(); ++i) {
      (*labels_ptr)->emplace_back(json{});
    }
    // Rebuild typed json values: the index stores everything as text, so each
    // cell is converted according to the schema's declared column type.
    for (unsigned int i = 0; i < labels->size(); ++i) {
      json construct_json;
      for (unsigned int j = 0; j < columns.size(); ++j) {
        // construct json "f1": value
        auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];

        // convert the string to base type by schema
        if (schema[columns[j]]["type"] == "int32") {
          construct_json[columns[j]] = StringToNum<int32_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "int64") {
          construct_json[columns[j]] = StringToNum<int64_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float32") {
          construct_json[columns[j]] = StringToNum<float>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float64") {
          construct_json[columns[j]] = StringToNum<double>((*labels)[i][j]);
        } else {
          construct_json[columns[j]] = std::string((*labels)[i][j]);
        }
      }
      (*(*labels_ptr))[i] = construct_json;
    }
    return Status::OK();
  }
  return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr);
}

// Strict-weak-ordering comparator for row-group summaries: order primarily by
// element 1 (shard id), breaking ties by element 0 (group id), ascending.
bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, int> b) {
  if (std::get<1>(a) != std::get<1>(b)) {
    return std::get<1>(a) < std::get<1>(b);
  }
  return std::get<0>(a) < std::get<0>(b);
}

// Count the number of distinct values of `category_field` across all shards by
// querying each shard's sqlite index on its own thread.  Returns -1 on error.
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
  auto shard_count = file_paths_.size();
  auto index_fields = shard_header_->GetFields();

  std::map<std::string, int64_t> map_schema_id_fields;
  for (auto &field : index_fields) {
    map_schema_id_fields[field.second] = field.first;
  }

  if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
    MS_LOG(ERROR) << "[Internal ERROR] 'category_field' " << category_field
                  << " can not found in index fields of mindrecord files.";
    return -1;
  }
  std::shared_ptr<std::string> fn_ptr;
  (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field),
                                               &fn_ptr);
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
  auto category_ptr = std::make_shared<std::set<std::string>>();
  sqlite3 *db = nullptr;

  // Join threads already launched before an early exit: destroying a joinable
  // std::thread calls std::terminate.
  auto join_started_threads = [&threads](int count) {
    for (int i = 0; i < count; ++i) {
      if (threads[i].joinable()) {
        threads[i].join();
      }
    }
  };

  for (int x = 0; x < static_cast<int>(shard_count); x++) {
    std::string path_utf8 = "";
#if defined(_WIN32) || defined(_WIN64)
    path_utf8 = FileUtils::GB2312ToUTF_8((file_paths_[x] + ".db").data());
#endif
    if (path_utf8.empty()) {
      path_utf8 = file_paths_[x] + ".db";
    }

#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
    // use "unix-none" to avoid flock and achieve better performance on shared storage platform
    int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, "unix-none");
#else
    int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, nullptr);
#endif
    if (SQLITE_OK != rc) {
      MS_LOG(ERROR) << "[Internal ERROR] Failed to open meta file: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
      // Per the SQLite docs, a handle is usually allocated even when open
      // fails and must still be released with sqlite3_close.
      (void)sqlite3_close(db);
      join_started_threads(x);
      return -1;
    }

    // starting a transaction during a read-only select operation can solve the problem of frequently
    // accessing *-journal / *-wal files.
    auto sql_code = sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
    if (sql_code != SQLITE_OK) {
      // sqlite3_close (not sqlite3_free) is the API that releases a database
      // connection; sqlite3_free is only for sqlite3_malloc'ed memory.
      (void)sqlite3_close(db);
      MS_LOG(ERROR) << "Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: "
                    << std::to_string(sql_code);
      join_started_threads(x);
      return -1;
    }
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
  }

  for (int x = 0; x < static_cast<int>(shard_count); x++) {
    threads[x].join();
  }
  // NOTE(review): only the connection opened in the last loop iteration is
  // closed here; earlier connections are presumably released inside
  // GetClassesInShard — confirm to rule out a double-close / leak.
  sqlite3_close(db);
  return category_ptr->size();
}

// Compute how many rows the reader will deliver after applying the operator
// chain `ops` (walked from leaf to root) plus `num_padded` padded samples.
// The result is written to *count; `load_dataset` is forwarded to Init.
Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                                   const std::shared_ptr<ShardOperator> &ops, int64_t *count,
                                   const int64_t num_padded) {
  RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));
  int64_t num_samples = num_rows_;
  // `root` ensures padded samples are only accounted once along the chain.
  bool root = true;
  // Flatten the child chain onto a stack so operators apply bottom-up.
  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
  std::shared_ptr<ShardOperator> op(ops);
  while (op != nullptr) {
    stack_ops.push(op);
    op = op->GetChildOp();
  }
  while (!stack_ops.empty()) {
    op = stack_ops.top();
    stack_ops.pop();
    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
      num_samples = op->GetNumSamples(num_samples, 0);
      if (num_padded > 0 && root == true) {
        num_samples += num_padded;
        root = false;
      }
    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      // Category sampling: the count depends on how many distinct classes the
      // category field has.
      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
      std::string category_field = category_op->GetCategoryField();
      auto num_classes = GetNumClasses(category_field);
      num_samples = category_op->GetNumSamples(num_samples, num_classes);
      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
        // PK sampling may further cap the total with its own num_samples.
        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
        if (tmp != 0 && num_samples != -1) {
          num_samples = std::min(num_samples, tmp);
        }

        CHECK_FAIL_RETURN_UNEXPECTED_MR(num_samples != -1,
                                        "Invalid data, 'num_samples': " + std::to_string(num_samples) +
                                          " is out of bound: " + std::to_string(std::numeric_limits<int64_t>::max()));
      }
    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
        if (root == true) {
          // Distributed sampling needs the padded count so the total divides
          // evenly by num_shards; -1 signals that it does not.
          sampler_op->SetNumPaddedSamples(num_padded);
          num_samples = op->GetNumSamples(num_samples, 0);
          CHECK_FAIL_RETURN_UNEXPECTED_MR(
            num_samples != -1,
            "Invalid data, the size of dataset and padded samples: " + std::to_string(num_padded) +
              " can not be divisible by the value of 'num_shards'.\n Please adjust the value of 'num_padded'.");
          root = false;
        }
      } else {
        num_samples = op->GetNumSamples(num_samples, 0);
        num_samples += num_padded;
      }
    } else {
      // Any other operator leaves the count unchanged apart from padding.
      if (num_padded > 0) {
        num_samples += num_padded;
      }
    }
  }
  *count = num_samples;
  return Status::OK();
}

// Open the mindrecord files, validate the selected columns, and initialize the
// reader's counters before delegating to Open(n_consumer) for the streams.
// The effective consumer count is clamped into [kMinConsumerCount, thread limit].
Status ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
                         const std::vector<std::string> &selected_columns,
                         const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
                         LoadMode load_mode) {
  load_mode_ = load_mode;
  operators_ = operators;

  // Open file and set header by ShardReader
  RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));

  // Clamp the requested consumer count to the platform thread limit from
  // above and kMinConsumerCount from below.
  auto thread_limit = GetMaxThreadNum();
  n_consumer = (n_consumer > thread_limit) ? static_cast<int>(thread_limit) : n_consumer;
  n_consumer = (n_consumer < kMinConsumerCount) ? kMinConsumerCount : n_consumer;

  selected_columns_ = selected_columns;
  RETURN_IF_NOT_OK_MR(CheckColumnList(selected_columns_));

  // Initialize argument
  shard_count_ = static_cast<int>(file_paths_.size());
  n_consumer_ = n_consumer;
  num_padded_ = num_padded;

  RETURN_IF_NOT_OK_MR(Open(n_consumer));
  return Status::OK();
}

// Build the task list from the row-group summary and, unless `is_sample_read`
// is set, start the consumer threads that read rows in parallel.
Status ShardReader::Launch(bool is_sample_read) {
  // Get all row groups' info
  auto row_group_summary = ReadRowGroupSummary();

  // Sort row group by (group_id, shard_id), prepare for parallel reading
  std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
  auto status = CreateTasks(row_group_summary, operators_);
  if (status.IsError()) {
    interrupt_ = true;
    return status;
  }
  if (is_sample_read) {
    return Status::OK();
  }
  // Validate the consumer count before using it to size the thread vector.
  // (Message fixed: the old text ran "...kMaxConsumerCount" and "but got"
  // together with no separator.)
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
                                  "Invalid data, 'num_parallel_workers' should be less than or equal to " +
                                    std::to_string(kMaxConsumerCount) + " but got: " + std::to_string(n_consumer_));

  // Start provider consumer threads
  thread_set_ = std::vector<std::thread>(n_consumer_);
  for (int x = 0; x < n_consumer_; ++x) {
    thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
  }

  MS_LOG(INFO) << "Succeed to launch read thread.";
  return Status::OK();
}

// Build the task list for category-based sampling (ShardCategory/ShardPkSample):
// for each category, walk every shard's matching pages and insert up to
// `num_elements` tasks, then combine the per-category lists into tasks_.
Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
  CheckIfColumnInIndex(selected_columns_);
  auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
  auto categories = category_op->GetCategories();
  int64_t num_elements = category_op->GetNumElements();
  int64_t num_samples = 0;
  if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
    num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      num_samples >= 0,
      "Invalid data, 'num_samples' should be greater than or equal to 0, but got: " + std::to_string(num_samples));
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    num_elements > 0,
    "[Internal ERROR] 'num_elements' should be greater than 0, but got: " + std::to_string(num_elements));
  // No explicit categories: enumerate up to num_categories distinct values of
  // the category field from the index.
  if (categories.empty() == true) {
    std::string category_field = category_op->GetCategoryField();
    int64_t num_categories = category_op->GetNumCategories();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      num_categories > 0,
      "[Internal ERROR] 'num_categories' should be greater than 0, but got: " + std::to_string(num_categories));
    auto category_ptr = std::make_shared<std::set<std::string>>();
    RETURN_IF_NOT_OK_MR(GetAllClasses(category_field, category_ptr));
    int i = 0;
    for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
      categories.emplace_back(category_field, *it);
      i++;
    }
  }
  // Generate a vector of task lists.  Each catogory has a list of tasks.
  std::vector<ShardTaskList> categoryTasks(categories.size());
  for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
    // category_index counts the tasks already created for this category,
    // capping the total at num_elements across all shards.
    int category_index = 0;
    for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
      auto pages_ptr = std::make_shared<std::vector<uint64_t>>();
      RETURN_IF_NOT_OK_MR(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr));
      for (const auto &page_id : *pages_ptr) {
        if (category_index >= num_elements) {
          break;
        }
        std::shared_ptr<Page> page_ptr;
        RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, page_id, &page_ptr));
        auto group_id = page_ptr->GetPageTypeID();
        // Read only the rows of this page that match the category criteria.
        std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
        RETURN_IF_NOT_OK_MR(
          ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr));
        auto offsets = std::get<3>(*row_group_brief_ptr);

        // One task per matching row: tuple element 3 holds the blob offsets,
        // element 4 the label json.
        auto number_of_rows = offsets.size();
        for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
          if (category_index < num_elements) {
            categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
                                                 std::get<3>(*row_group_brief_ptr)[iStart],
                                                 std::get<4>(*row_group_brief_ptr)[iStart]);
            category_index++;
          }
        }
        MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks.";
      }
    }
  }
  // Merge the per-category lists (with or without replacement) into tasks_.
  tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);

  tasks_.InitSampleIds();
  // Let the category operator post-process (e.g. reorder/sample) the tasks.
  RETURN_IF_NOT_OK_MR((*category_op)(tasks_));
  return Status::OK();
}

// Fast-load task creation: read all row-group metadata up front, then fill
// tasks_ in parallel with one kCommonTask per sample (one worker thread per
// shard, each writing a disjoint offset range of the pre-sized task list).
Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                     const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  std::shared_ptr<ROW_GROUPS> row_group_ptr;
  RETURN_IF_NOT_OK_MR(ReadAllRowGroup(selected_columns_, &row_group_ptr));
  // tuple element 0: per-shard offset records; element 1: per-shard labels.
  auto &offsets = std::get<0>(*row_group_ptr);
  auto &local_columns = std::get<1>(*row_group_ptr);
  int sample_count = 0;
  for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
    sample_count += offsets[shard_id].size();
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(sample_count == num_rows_, "Unequal number of index entries and data entries.");
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  // current_offset is each shard's starting slot in tasks_; captured by value
  // so every thread writes its own disjoint range.
  uint32_t current_offset = 0;
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateTasks" + std::to_string(shard_id)).c_str());
#endif
      auto offset = current_offset;
      // offsets[shard_id][i] layout: [0] shard id, [1] group id,
      // [2] blob start, [3] blob end.
      for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
        tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
                          std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
                          local_columns[shard_id][i]);
        offset++;
      }
    });
    current_offset += offsets[shard_id].size();
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}

// Lazy-load task creation: insert one placeholder kCommonTask per sample
// (empty offsets and label json — actual data is read later on demand), using
// one worker thread per shard.  shard_sample_count_ holds cumulative per-shard
// sample counts, so the last entry is the dataset total.
Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                         const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // the offset indicate the shard start
    uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];

    // the count indicate the number of samples in the shard
    uint32_t shard_count =
      shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
    // Each thread fills a disjoint [current_offset, current_offset + shard_count)
    // range of tasks_, so no synchronization is needed between threads.
    init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateLazyTasks" + std::to_string(shard_id)).c_str());
#endif
      for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
        // here "i - current_offset" indicate the sample id in the shard
        tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
      }
    });
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}

Status ShardReader::CreateSlowTasksByRow() {
  CheckIfColumnInIndex(selected_columns_);
  // Last prefix-sum entry equals the total row count of the dataset.
  const uint32_t total_rows = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << total_rows << " records from dataset.";
  // In slow mode no per-row task list is materialized; only the bookkeeping
  // that the task generator needs is recorded.
  tasks_.SetShardSampleCount(shard_sample_count_);
  tasks_.padded_sample_ = num_padded_;
  return Status::OK();
}

// Build the task list according to the load mode and then apply every
// non-category operator (samplers/shuffles) to it. 'row_group_summary'
// describes the row groups of each shard; 'operators' holds at most one
// category operator, which switches task creation to category retrieval.
Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  // Locate the first category (retrieval-by-category) operator, if any.
  int category_operator = -1;
  for (uint32_t i = 0; i < operators.size(); ++i) {
    const auto &op = operators[i];
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      category_operator = static_cast<int>(i);
      break;
    }
  }

  if (-1 == category_operator) {
    if (load_mode_ != LoadMode::kSlow) {
      try {
        if (load_mode_ == LoadMode::kLazy) {
          RETURN_IF_NOT_OK_MR(CreateLazyTasksByRow(row_group_summary, operators));
        } else {
          RETURN_IF_NOT_OK_MR(CreateTasksByRow(row_group_summary, operators));
        }

        // need padded sample to the task
        if (num_padded_ > 0) {
          for (auto i = 0; i < num_padded_; ++i) {
            tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
          }
        }
      } catch (std::bad_alloc &ba) {
        // Materializing all tasks can exhaust memory on very large datasets;
        // point the user at partial shuffle which keeps the footprint small.
        MS_LOG(EXCEPTION) << "bad_alloc caught: " << ba.what() << ". Out of memory, please use parameter "
                          << "shuffle=Shuffle.PARTIAL do shuffle in MindDataset(...) / RandomSampler(...) / "
                          << "DistributedSampler(...) which will use less memory.";
      }
    } else {
      // Slow mode only records bookkeeping; no task list is materialized.
      RETURN_IF_NOT_OK_MR(CreateSlowTasksByRow());
    }
  } else {
    RETURN_IF_NOT_OK_MR(CreateTasksByCategory(operators[category_operator]));
  }

  MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
  if (load_mode_ != LoadMode::kSlow) {
    tasks_.InitSampleIds();

    // Apply operators in order; category ops were consumed above and are skipped.
    for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
      const auto &op = operators[operator_no];
      if (std::dynamic_pointer_cast<ShardCategory>(op)) {
        continue;
      }

      // Distributed samplers and shuffles need the per-shard prefix sums.
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
        op->SetShardSampleCount(shard_sample_count_);
      }
      RETURN_IF_NOT_OK_MR((*op)(tasks_));
    }

    if (tasks_.permutation_.empty()) {
      tasks_.MakePerm();
    }
  } else {
    // Slow mode cannot serve category retrieval; all other operators apply as usual.
    for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
      const auto &op = operators[operator_no];
      CHECK_FAIL_RETURN_UNEXPECTED_MR(
        !std::dynamic_pointer_cast<ShardCategory>(op),
        "[Internal ERROR] The retrieval function is not available when in slow loading mode.");
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
        op->SetShardSampleCount(shard_sample_count_);
      }
      RETURN_IF_NOT_OK_MR((*op)(tasks_));
    }
  }

  num_rows_ = tasks_.Size();
  MS_LOG(INFO) << "The total number of samples is " << num_rows_
               << ", the number of samples after sampling is: " << tasks_.SizeAfterSampling();

  return Status::OK();
}

// Materialize one sample: look up the task, resolve its blob location
// (directly for fast mode, via an index query for lazy/slow mode), read the
// blob bytes from the shard file owned by 'consumer_id', and return the
// (blob, scalar-json) pair through 'task_content_ptr'.
Status ShardReader::ConsumerOneTask(int64_t task_id, uint32_t consumer_id,
                                    std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(task_content_ptr);
  if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
    // All tasks are done
    CHECK_FAIL_RETURN_UNEXPECTED_MR(task_id < tasks_.Size(), "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
                                                               " is out of bound: " + std::to_string(tasks_.Size()));
  } else {
    // Slow mode has no materialized task list; bound-check against the total
    // row count (last prefix-sum entry) plus padded samples instead.
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      task_id < (num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]),
      "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
        " is out of bound: " + std::to_string(num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]));
  }

  uint32_t shard_id = 0;
  uint32_t group_id = 0;
  uint32_t blob_start = 0;
  uint32_t blob_end = 0;
  json var_fields;
  // Pick up task from task list
  ShardTask task = tasks_.GetTaskByID(task_id);

  // check task type
  auto task_type = std::get<0>(task);
  if (task_type == TaskType::kPaddedTask) {
    // Padded samples carry no data; return an empty content of padded type.
    *task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    return Status::OK();
  }

  shard_id = std::get<0>(std::get<1>(task));  // shard id

  if (load_mode_ == LoadMode::kLazy || load_mode_ == LoadMode::kSlow) {
    // get scalar variable fields by sample id
    uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));

    // read the meta from index
    std::shared_ptr<ROW_GROUPS> row_group_ptr;
    RETURN_IF_NOT_OK_MR(
      ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, consumer_id, sample_id_in_shard, &row_group_ptr));
    auto &offsets = std::get<0>(*row_group_ptr);
    auto &local_columns = std::get<1>(*row_group_ptr);

    group_id = offsets[shard_id][0][1];       // group_id
    blob_start = offsets[shard_id][0][2];     // blob start
    blob_end = offsets[shard_id][0][3];       // blob end
    var_fields = local_columns[shard_id][0];  // scalar variable field
  } else {
    // Fast mode stored everything in the task when it was created.
    group_id = std::get<1>(std::get<1>(task));  // group id
    blob_start = std::get<2>(task)[0];          // blob start
    blob_end = std::get<2>(task)[1];            // blob end
    var_fields = std::get<3>(task);             // scalar variable field
  }

  // read the blob from data file
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;

  // Pack image list
  std::vector<uint8_t> images(blob_end - blob_start);
  // File layout: header, then fixed-size pages; blob_start is relative to the page.
  auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start;

  // Each consumer owns its own stream per shard, so no locking is needed here.
  auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
  }
  auto &io_read =
    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
  }

  // Deliver batch data to output map
  std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
  batch.emplace_back(std::move(images), std::move(var_fields));

  *task_content_ptr = std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::move(batch));
  return Status::OK();
}

// Worker-thread loop: repeatedly claims the next sample position, reads the
// sample via ConsumerOneTask, and publishes it into delivery_map_ for GetNext.
// Back-pressure: a worker blocks on cv_delivery_ while it is more than
// kNumBatchInMap positions ahead of the consumer (deliver_id_).
void ShardReader::ConsumerByRow(int consumer_id) {
  // Set thread name
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(consumer_id)).c_str());
#endif

  // Loop forever
  for (;;) {
    int64_t sample_id_pos = 0;

    // Get next task ID
    // sample_id_position_ is shared between workers; the increment claims a
    // unique position for this iteration.
    sample_id_pos = sample_id_position_++;

    auto task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    int64_t task_id = 0;

    if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
      // All tasks are done
      if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
        return;
      }
      task_id = tasks_.sample_ids_[sample_id_pos];
    } else {
      // task_id is not correct when slow load mode
      // Slow mode has no sample id list; the position itself is used and the
      // total row count (last prefix-sum entry) bounds the loop.
      if (sample_id_pos >= shard_sample_count_[shard_sample_count_.size() - 1]) {
        return;
      }
      task_id = sample_id_pos;
    }
    if (ConsumerOneTask(task_id, consumer_id, &task_content_ptr).IsError()) {
      // Signal the reader side so GetNext does not wait forever.
      MS_LOG(ERROR) << "[Internal ERROR] Error raised in ConsumerOneTask function.";
      interrupt_ = true;
      cv_iterator_.notify_one();
      return;
    }
    const auto &batch = (*task_content_ptr).second;
    // Hanging if maximum map size exceeded
    //   otherwise, set batch data in map
    {
      std::unique_lock<std::mutex> lck(mtx_delivery_);
      cv_delivery_.wait(lck,
                        [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
      if (interrupt_) {
        return;
      }
      delivery_map_[sample_id_pos] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(batch);
    }
    // Wake the consumer waiting in GetNext for this position.
    cv_iterator_.notify_one();
  }
}

std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> ShardReader::GetNext() {
  if (interrupt_) {
    return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
  }

  if (deliver_id_ >= static_cast<int>(tasks_.SizeAfterSampling())) {
    return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
  }

  std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> res;
  {
    std::unique_lock<std::mutex> lck(mtx_delivery_);
    cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); });
    if (interrupt_) {
      return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
    }
    res = delivery_map_[deliver_id_];
    delivery_map_.erase(deliver_id_++);
  }

  cv_delivery_.notify_all();

  // extract every blob field from blob data
  std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> res_with_blobs;
  for (auto iter = res->begin(); iter != res->end(); iter++) {
    std::map<std::string, std::vector<uint8_t>> key_with_blob_fields;
    auto shard_column = GetShardColumn();
    auto schema = shard_header_->GetSchemas();  // current, we only support 1 schema yet
    auto blob_fields = schema[0]->GetBlobFields();
    for (auto blob_field : blob_fields) {
      const unsigned char *data = nullptr;
      std::unique_ptr<unsigned char[]> data_ptr;
      uint64_t n_bytes = 0;
      mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
      uint64_t column_data_type_size = 1;
      std::vector<int64_t> column_shape;
      if (shard_column->GetColumnValueByName(blob_field, std::get<0>(*iter), std::get<1>(*iter), &data, &data_ptr,
                                             &n_bytes, &column_data_type, &column_data_type_size,
                                             &column_shape) != Status::OK()) {
        MS_LOG(ERROR) << "[Internal ERROR] Failed to extract blob fields from blob data";
        return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
      }
      key_with_blob_fields[blob_field] = std::vector<uint8_t>(data, data + n_bytes);
    }

    res_with_blobs.emplace_back(std::move(key_with_blob_fields), std::move(std::get<1>(*iter)));
  }

  return res_with_blobs;
}

Status ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id,
                                std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
  // Random-access read: on interrupt the output is left untouched and success
  // is reported so the caller can wind down cleanly.
  if (!interrupt_) {
    RETURN_IF_NOT_OK_MR(ConsumerOneTask(task_id, consumer_id, task_content_ptr));
  }
  return Status::OK();
}

// Split one sample's raw blob bytes into per-column byte vectors, appended to
// *blob_data_ptr in the order the loaded columns appear. Non-blob columns are
// skipped; when no columns were selected, every schema column is loaded.
Status ShardReader::UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
                                   std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(blob_data_ptr);
  auto loaded_columns = selected_columns_.empty() ? shard_column_->GetColumnName() : selected_columns_;
  auto blob_fields = GetBlobFields().second;
  for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
    // Only blob columns live inside the raw blob data; skip scalar columns.
    if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) {
      continue;
    }
    const unsigned char *data = nullptr;
    std::unique_ptr<unsigned char[]> data_ptr;
    uint64_t n_bytes = 0;
    RETURN_IF_NOT_OK_MR(
      shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes));
    // GetColumnFromBlob returns either a view into raw_blob_data (data) or an
    // owned buffer (data_ptr); fall back to the owned buffer when needed.
    if (data == nullptr) {
      data = data_ptr.get();
    }
    // Construct the column in place (sizeof(unsigned char) is 1 by definition,
    // so n_bytes is already the element count).
    (*blob_data_ptr)->emplace_back(data, data + n_bytes);
  }
  return Status::OK();
}

// Report the accumulated blob size of the dataset through the out-parameter.
Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
  // Guard the out-parameter the same way the other accessors in this file do.
  RETURN_UNEXPECTED_IF_NULL_MR(total_blob_size);
  *total_blob_size = total_blob_size_;
  return Status::OK();
}

void ShardReader::Reset() {
  {
    // Rewind both cursors under the delivery lock so consumers never observe
    // a half-reset state.
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    deliver_id_ = 0;
    sample_id_position_ = 0;
  }
  // Wake any worker blocked on the delivery backlog limit.
  cv_delivery_.notify_all();
}

void ShardReader::ShuffleTask() {
  // exist shuffle and distributed sampler in ops, skip shuffle
  bool has_sharding = false;
  for (const auto &op : operators_) {
    if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
      has_sharding = true;
    }
  }
  for (const auto &op : operators_) {
    if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
      auto s = (*op)(tasks_);
      if (s.IsError()) {
        MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
      }
    } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
      auto s = (*op)(tasks_);
      if (s.IsError()) {
        MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
      }
    }
  }
  if (load_mode_ != kSlow) {
    if (tasks_.permutation_.empty()) {
      tasks_.MakePerm();
    }
  } else {
    tasks_.generator_ids_.ResetShardIndexAndID();
  }
}

const std::vector<int64_t> *ShardReader::GetSampleIds() {
  // return const reference to private sample id list.
  return &(this->tasks_.sample_ids_);
}

// Accessor for the configured load mode (kFast / kLazy / kSlow).
LoadMode ShardReader::GetLoadMode() const { return load_mode_; }

// Populate all_sampler_ids_ once so GetMappedIndex can resolve any position.
void ShardReader::GetSampleIdsByRandomAccess() {
  if (!all_sampler_ids_.empty()) {
    return;  // already cached; only computed once
  }
  all_sampler_ids_.reserve(num_rows_);
  if (GetLoadMode() != mindrecord::LoadMode::kSlow) {
    // Fast / lazy modes keep the full id list in memory already.
    all_sampler_ids_ = *GetSampleIds();
  } else {
    // Slow mode yields ids batch by batch; drain the generator until empty.
    for (auto batch = GetNextSampleIds(); !batch.empty(); batch = GetNextSampleIds()) {
      all_sampler_ids_.insert(all_sampler_ids_.end(), batch.begin(), batch.end());
    }
  }
}

// Map a logical position to the sampled row id it refers to.
// Raises (via MS_LOG(EXCEPTION)) when no sample ids exist or 'index' is out of range.
Status ShardReader::GetMappedIndex(size_t index, size_t *mapped_index) {
  // Guard the out-parameter like the other out-param functions in this file.
  RETURN_UNEXPECTED_IF_NULL_MR(mapped_index);
  GetSampleIdsByRandomAccess();
  if (all_sampler_ids_.empty()) {
    // Note, all_sampler_ids_.empty() is okay and will just give no sample ids.
    MS_LOG(EXCEPTION) << "[Internal ERROR] Init Sampler failed as sample_ids is empty, here ShardReader did not "
                         "provide a valid sample ids vector via MindRecordOp.";
  }
  if (index >= all_sampler_ids_.size()) {
    MS_LOG(EXCEPTION) << "Input index is not within the required interval of [0, " << all_sampler_ids_.size() - 1
                      << "], but got " << index << ".";
  }
  *mapped_index = all_sampler_ids_[index];
  return Status::OK();
}

// Fetch the next batch of sample ids from the task generator (used in slow load mode).
std::vector<int64_t> ShardReader::GetNextSampleIds() { return tasks_.GetNextSampleIds(); }
}  // namespace mindrecord
}  // namespace mindspore
